{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7317c8a8",
   "metadata": {},
   "source": [
    "# Train a model to schedule events using GRPO\n",
    "\n",
    "➡️ For a complete walk-through, read [this blog post](https://huggingface.co/blog/anakin87/qwen-scheduler-grpo).\n",
    "\n",
    "Once trained, the model should be able to solve problems like the following.\n",
    "\n",
    "**Example input**\n",
    "\n",
    "Task: create an optimized schedule based on the given events. Maximize the total weighted duration of the events.\n",
    "*(For the detailed prompt, see below).*\n",
    "\n",
    "Events:\n",
    "- Event A (01:27 - 01:42)\n",
    "- Event B (01:15 - 02:30)\n",
    "- Event C (15:43 - 17:43)\n",
    "\n",
    "Priorities:\n",
    "- Event B\n",
    "\n",
    "**Example output**\n",
    "\n",
    "```xml\n",
    "<think>A detailed reasoning</think>\n",
    "<schedule>\n",
    "<event>\n",
    "<name>Event B</name>\n",
    "<start>01:15</start>\n",
    "<end>02:30</end>\n",
    "</event>\n",
    "<event>\n",
    "<name>Event C</name>\n",
    "<start>15:43</start>\n",
    "<end>17:43</end>\n",
    "</event>\n",
    "</schedule>\n",
    "```\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
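  {
   "cell_type": "markdown",
   "id": "weighted-duration-example",
   "metadata": {},
   "source": [
    "To make the objective concrete, the following is a minimal sketch of how the example schedule above could be scored. It assumes that priority events count double (weight = 2, as the prompt rules later in this notebook state) and that durations are measured in minutes. The helper names `to_minutes` and `weighted_duration` are illustrative and are not part of the training code.\n",
    "\n",
    "```python\n",
    "def to_minutes(hhmm):\n",
    "    # Convert an 'HH:MM' string to minutes since midnight\n",
    "    hours, minutes = map(int, hhmm.split(':'))\n",
    "    return hours * 60 + minutes\n",
    "\n",
    "\n",
    "def weighted_duration(schedule, priorities):\n",
    "    # Sum the event durations; priority events count twice (assumed weight = 2)\n",
    "    total = 0\n",
    "    for name, start, end in schedule:\n",
    "        weight = 2 if name in priorities else 1\n",
    "        total += weight * (to_minutes(end) - to_minutes(start))\n",
    "    return total\n",
    "\n",
    "\n",
    "schedule = [('Event B', '01:15', '02:30'), ('Event C', '15:43', '17:43')]\n",
    "print(weighted_duration(schedule, {'Event B'}))  # 2 * 75 + 120 = 270\n",
    "```\n",
    "\n",
    "Event A overlaps Event B, so only one of the two can be scheduled. Keeping Event A instead of Event B would give 15 + 120 = 135, which is why the example output keeps Event B."
   ]
  },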
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49931a13-0efd-4f04-b3dd-e9dedf8659e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install --upgrade pip\n",
    "! pip install \"unsloth==2025.3.19\" vllm wandb\n",
    "! pip uninstall -y typing_extensions &&  pip install typing_extensions==4.11.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d1e13364-e2b7-4620-8a1a-9bcb10980b0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "! rm -rf outputs completion_samples"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97af56ba",
   "metadata": {},
   "source": [
    "## Load the original model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ac15f967-f4ca-47e7-b9b9-da2fadc4f0b6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n",
      "🦥 Unsloth Zoo will now patch everything to make training faster!\n",
      "INFO 04-04 12:31:12 [__init__.py:239] Automatically detected platform cuda.\n",
      "==((====))==  Unsloth 2025.3.19: Fast Qwen2 patching. Transformers: 4.50.3. vLLM: 0.8.2.\n",
      "   \\\\   /|    NVIDIA RTX A6000. Num GPUs = 1. Max memory: 47.438 GB. Platform: Linux.\n",
      "O^O/ \\_/ \\    Torch: 2.6.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.2.0\n",
      "\\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post2. FA2 = False]\n",
      " \"-____-\"     Free license: http://github.com/unslothai/unsloth\n",
      "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n",
      "Unsloth: vLLM loading unsloth/qwen2.5-coder-7b-instruct-bnb-4bit with actual GPU utilization = 84.46%\n",
      "Unsloth: Your GPU has CUDA compute capability 8.6 with VRAM = 47.44 GB.\n",
      "Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 2048. Num Sequences = 320.\n",
      "Unsloth: vLLM's KV Cache can use up to 34.09 GB. Also swap space = 6 GB.\n",
      "INFO 04-04 12:31:25 [config.py:585] This model supports multiple tasks: {'reward', 'generate', 'classify', 'embed', 'score'}. Defaulting to 'generate'.\n",
      "WARNING 04-04 12:31:25 [arg_utils.py:1854] --quantization bitsandbytes is not supported by the V1 Engine. Falling back to V0. \n",
      "Unsloth: vLLM Bitsandbytes config using kwargs = {'load_in_8bit': False, 'load_in_4bit': True, 'bnb_4bit_compute_dtype': 'bfloat16', 'bnb_4bit_quant_storage': 'uint8', 'bnb_4bit_quant_type': 'nf4', 'bnb_4bit_use_double_quant': True, 'llm_int8_enable_fp32_cpu_offload': False, 'llm_int8_has_fp16_weight': False, 'llm_int8_skip_modules': [], 'llm_int8_threshold': 6.0}\n",
      "INFO 04-04 12:31:25 [llm_engine.py:241] Initializing a V0 LLM engine (v0.8.2) with config: model='unsloth/qwen2.5-coder-7b-instruct-bnb-4bit', speculative_config=None, tokenizer='unsloth/qwen2.5-coder-7b-instruct-bnb-4bit', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.BITSANDBYTES, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=bitsandbytes, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda:0, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=unsloth/qwen2.5-coder-7b-instruct-bnb-4bit, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=False, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={\"level\":0,\"splitting_ops\":[],\"compile_sizes\":[],\"cudagraph_capture_sizes\":[320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],\"max_capture_size\":320}, use_cached_outputs=False, \n",
      "INFO 04-04 12:31:26 [cuda.py:291] Using Flash Attention backend.\n",
      "INFO 04-04 12:31:26 [parallel_state.py:954] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0\n",
      "INFO 04-04 12:31:26 [model_runner.py:1110] Starting to load model unsloth/qwen2.5-coder-7b-instruct-bnb-4bit...\n",
      "INFO 04-04 12:31:26 [loader.py:1155] Loading weights with BitsAndBytes quantization. May take a while ...\n",
      "INFO 04-04 12:31:28 [weight_utils.py:265] Using model weights format ['*.safetensors']\n",
      "INFO 04-04 12:31:28 [weight_utils.py:315] No model.safetensors.index.json found in remote.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "613e9f2eef4548128ada832f878bd226",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c4d917aeaf8e4923a7f7a50d5d690dbd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO 04-04 12:31:32 [punica_selector.py:18] Using PunicaWrapperGPU.\n",
      "INFO 04-04 12:31:32 [model_runner.py:1146] Model loading took 5.3638 GB and 5.440341 seconds\n",
      "INFO 04-04 12:31:34 [worker.py:267] Memory profiling takes 2.03 seconds\n",
      "INFO 04-04 12:31:34 [worker.py:267] the current vLLM instance can use total_gpu_memory (47.44GiB) x gpu_memory_utilization (0.84) = 40.07GiB\n",
      "INFO 04-04 12:31:34 [worker.py:267] model weights take 5.36GiB; non_torch_memory takes 0.06GiB; PyTorch activation peak memory takes 1.76GiB; the rest of the memory reserved for KV Cache is 32.89GiB.\n",
      "INFO 04-04 12:31:35 [executor_base.py:111] # cuda blocks: 38487, # CPU blocks: 7021\n",
      "INFO 04-04 12:31:35 [executor_base.py:116] Maximum concurrency for 2048 tokens per request: 300.68x\n",
      "INFO 04-04 12:31:40 [model_runner.py:1442] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Capturing CUDA graph shapes: 100%|██████████| 43/43 [00:36<00:00,  1.19it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO 04-04 12:32:16 [model_runner.py:1570] Graph capturing finished in 36 secs, took 0.85 GiB\n",
      "INFO 04-04 12:32:16 [llm_engine.py:447] init engine (profile, create kv cache, warmup model) took 44.60 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Unsloth 2025.3.19 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.\n"
     ]
    }
   ],
   "source": [
    "from unsloth import FastLanguageModel\n",
    "\n",
    "max_seq_length = 2048  # Can increase for longer reasoning traces\n",
    "lora_rank = 32  # Larger rank = smarter, but slower\n",
    "\n",
    "model, tokenizer = FastLanguageModel.from_pretrained(\n",
    "    model_name=\"Qwen/Qwen2.5-Coder-7B-Instruct\",\n",
    "    max_seq_length=max_seq_length,\n",
    "    load_in_4bit=True,  # False for LoRA 16bit\n",
    "    fast_inference=True,  # Enable vLLM fast inference\n",
    "    max_lora_rank=lora_rank,\n",
    "    gpu_memory_utilization=0.85,  # Reduce if out of memory\n",
    ")\n",
    "\n",
    "model = FastLanguageModel.get_peft_model(\n",
    "    model,\n",
    "    r=lora_rank,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
    "    target_modules=[\n",
    "        \"q_proj\",\n",
    "        \"k_proj\",\n",
    "        \"v_proj\",\n",
    "        \"o_proj\",\n",
    "        \"gate_proj\",\n",
    "        \"up_proj\",\n",
    "        \"down_proj\",\n",
    "    ],  # Remove QKVO if out of memory\n",
    "    lora_alpha=lora_rank,\n",
    "    use_gradient_checkpointing=\"unsloth\",  # Enable long context finetuning\n",
    "    random_state=3407,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c18f7119",
   "metadata": {},
   "source": [
    "## Dataset preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dc2c3319-a946-4999-b47c-655b0fa21ce1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3032fa9e43474b62a8c505817704b962",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/500 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'events': [['Backstage tour VIP pass', '01:06', '01:21'],\n",
       "  ['Sum 41 concert', '14:57', '15:27'],\n",
       "  ['Korean BBQ food trucks', '16:21', '17:51'],\n",
       "  ['Fireworks', '08:48', '10:48'],\n",
       "  ['Jam session', '14:31', '15:31']],\n",
       " 'priority_events': ['Sum 41 concert'],\n",
       " 'optimal_score': 285,\n",
       " 'prompt': [{'content': 'You are a highly precise event scheduler. Your goal is to create an optimized schedule following strict constraints.\\n1. First, generate a detailed reasoning process inside <think> and </think> tags.  \\n2. Then, generate the schedule inside <schedule> and </schedule> tags.',\n",
       "   'role': 'system'},\n",
       "  {'content': \"Task: create an optimized schedule based on the given events.\\n\\nRules:\\n- The schedule MUST be in chronological order.\\n- Event start and end times are ABSOLUTE. Do NOT change, shorten, adjust, or split them.\\n- Priority events (weight = 2) are more important than normal events (weight = 1).\\n- Maximize the sum of weighted event durations.\\n- No two events can overlap. If two events conflict, keep the one that maximizes total weighted time. \\n- It's acceptable to exclude some events if necessary to follow the above constraints.\\n\\nAlways follow this format:  \\n\\n<think>...</think>\\n<schedule>\\n<event>\\n<name>...</name>\\n<start>...</start>\\n<end>...</end>\\n</event>\\n...\\n</schedule>\\n\\n---\\n\\nEvents:\\n- Backstage tour VIP pass (01:06 - 01:21)\\n- Sum 41 concert (14:57 - 15:27)\\n- Korean BBQ food trucks (16:21 - 17:51)\\n- Fireworks (08:48 - 10:48)\\n- Jam session (14:31 - 15:31)\\n\\nPriorities:\\n- Sum 41 concert\",\n",
       "   'role': 'user'}]}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import datasets\n",
    "\n",
    "SYSTEM_PROMPT = \"\"\"You are a precise event scheduler.\n",
    "1. First, reason through the problem inside <think> and </think> tags. Here you can create drafts, compare alternatives, and check for mistakes.\n",
    "2. When confident, output the final schedule inside <schedule> and </schedule> tags. Your schedule must strictly follow the rules provided by the user.\"\"\"\n",
    "\n",
    "USER_PROMPT = \"\"\"Task: create an optimized schedule based on the given events.\n",
    "\n",
    "Rules:\n",
    "- The schedule MUST be in strict chronological order. Do NOT place priority events earlier unless their actual start time is earlier.\n",
    "- Event start and end times are ABSOLUTE. NEVER change, shorten, adjust, or split them.\n",
    "- Priority events (weight = 2) carry more weight than normal events (weight = 1), but they MUST still respect chronological order.\n",
    "- Maximize the sum of weighted event durations.\n",
    "- No overlaps allowed. In conflicts, include the event with the higher weighted time.\n",
    "- Some events may be excluded if needed to meet these rules.\n",
    "\n",
    "\n",
    "You must use this format:  \n",
    "\n",
    "<think>...</think>\n",
    "<schedule>\n",
    "<event>\n",
    "<name>...</name>\n",
    "<start>...</start>\n",
    "<end>...</end>\n",
    "</event>\n",
    "...\n",
    "</schedule>\n",
    "\n",
    "---\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "ds = datasets.load_dataset(\"anakin87/events-scheduling\", split=\"train\")\n",
    "ds\n",
    "\n",
    "ds = ds.map(\n",
    "    lambda x: {\n",
    "        \"prompt\": [\n",
    "            {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
    "            {\"role\": \"user\", \"content\": USER_PROMPT + x[\"prompt\"]},\n",
    "        ]\n",
    "    }\n",
    ")\n",
    "\n",
    "\n",
    "ds[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "276158e8",
   "metadata": {},
   "source": [
    "## Reward functions\n",
    "\n",
    "We use 3 reward functions:\n",
    "1. Format reward: ensure the output is in the correct format. (10 points)\n",
    "2. Sorted events reward: ensure the events are sorted in chronological order. (20 points)\n",
    "3. Score reward: ratio between the total weighted duration of the events and the optimal score computed with dynamic programming. (70 points)"
   ]
  },
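  {
   "cell_type": "markdown",
   "id": "optimal-score-dp-sketch",
   "metadata": {},
   "source": [
    "The dataset ships a precomputed `optimal_score` for every example, and this notebook does not show how it was produced. As a rough illustration of the dynamic programming idea, the sketch below uses the classic weighted interval scheduling recurrence. It is a stand-in under stated assumptions, not the dataset's actual generation code: the function name `optimal_weighted_duration` is hypothetical, priority events are assumed to weigh 2, and events that merely touch are treated as conflicting, mirroring the overlap check in `score_reward` below.\n",
    "\n",
    "```python\n",
    "from bisect import bisect_left\n",
    "\n",
    "\n",
    "def optimal_weighted_duration(events, priorities):\n",
    "    # events: [name, 'HH:MM', 'HH:MM'] triples, as in the dataset's `events` column\n",
    "    def to_minutes(hhmm):\n",
    "        hours, minutes = map(int, hhmm.split(':'))\n",
    "        return hours * 60 + minutes\n",
    "\n",
    "    # One (end, start, weighted value) tuple per event, sorted by end time\n",
    "    jobs = sorted(\n",
    "        (\n",
    "            to_minutes(end),\n",
    "            to_minutes(start),\n",
    "            (2 if name in priorities else 1) * (to_minutes(end) - to_minutes(start)),\n",
    "        )\n",
    "        for name, start, end in events\n",
    "    )\n",
    "    ends = [end for end, _, _ in jobs]\n",
    "\n",
    "    # best[i] = best total achievable using only the first i jobs\n",
    "    best = [0] * (len(jobs) + 1)\n",
    "    for i, (end, start, value) in enumerate(jobs):\n",
    "        # Jobs 0..k-1 end strictly before this job starts\n",
    "        k = bisect_left(ends, start)\n",
    "        # Either skip this job, or take it on top of the best compatible prefix\n",
    "        best[i + 1] = max(best[i], best[k] + value)\n",
    "    return best[-1]\n",
    "\n",
    "\n",
    "events = [\n",
    "    ['Backstage tour VIP pass', '01:06', '01:21'],\n",
    "    ['Sum 41 concert', '14:57', '15:27'],\n",
    "    ['Korean BBQ food trucks', '16:21', '17:51'],\n",
    "    ['Fireworks', '08:48', '10:48'],\n",
    "    ['Jam session', '14:31', '15:31'],\n",
    "]\n",
    "print(optimal_weighted_duration(events, ['Sum 41 concert']))  # 285\n",
    "```\n",
    "\n",
    "For the first training example (printed in the dataset section), this sketch returns 285, which agrees with the `optimal_score` stored in the dataset."
   ]
  },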
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e6f816bd-9763-4826-b922-e08404e8b414",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random, os\n",
    "from datetime import datetime\n",
    "import re\n",
    "\n",
    "\n",
    "def minutes_to_time(minutes):\n",
    "    \"\"\"Convert minutes since midnight to time string.\n",
    "\n",
    "    Args:\n",
    "        minutes (int): Minutes since midnight\n",
    "\n",
    "    Returns:\n",
    "        str: Time string in \"HH:MM\" format\n",
    "    \"\"\"\n",
    "    return f\"{minutes // 60:02d}:{minutes % 60:02d}\"\n",
    "\n",
    "\n",
    "def time_to_minutes(time_str):\n",
    "    \"\"\"Convert time string to minutes since midnight.\n",
    "\n",
    "    Args:\n",
    "        time_str (str): Time string in \"HH:MM\" format\n",
    "\n",
    "    Returns:\n",
    "        int: Minutes since midnight\n",
    "    \"\"\"\n",
    "    hours, mins = map(int, time_str.split(\":\"))\n",
    "    return hours * 60 + mins\n",
    "\n",
    "\n",
    "overall_pattern = r\"<think>.+</think>.*<schedule>.*(<event>.*<name>.+</name>.*<start>\\d{2}:\\d{2}</start>.*<end>\\d{2}:\\d{2}</end>.*</event>)+.*</schedule>\"\n",
    "overall_regex = re.compile(overall_pattern, re.DOTALL)\n",
    "\n",
    "capture_pattern = r\"\"\"\n",
    "    <event>\\s*\n",
    "        <name>([^<]+)</name>\\s*\n",
    "        <start>(\\d{2}:\\d{2})</start>\\s*\n",
    "        <end>(\\d{2}:\\d{2})</end>\\s*\n",
    "    </event>\n",
    "\"\"\"\n",
    "\n",
    "capture_regex = re.compile(capture_pattern, re.VERBOSE)\n",
    "\n",
    "\n",
    "def get_events(content):\n",
    "    \"\"\"Extract event information from XML-like content.\n",
    "\n",
    "    Args:\n",
    "        content (str): XML-like string containing event data\n",
    "\n",
    "    Returns:\n",
    "        list: List of tuples (name, start_time, end_time)\n",
    "    \"\"\"\n",
    "    return [\n",
    "        (match.group(1), match.group(2), match.group(3))\n",
    "        for match in capture_regex.finditer(content)\n",
    "    ]\n",
    "\n",
    "\n",
    "def format_reward(prompts, completions, **kwargs):\n",
    "    responses = [completion[0][\"content\"] for completion in completions]\n",
    "\n",
    "    return [\n",
    "        0.0 if not overall_regex.match(response) else 10.0 for response in responses\n",
    "    ]\n",
    "\n",
    "\n",
    "def score_reward(\n",
    "    prompts, completions, events, priority_events, optimal_score, **kwargs\n",
    "):\n",
    "    scores = []\n",
    "    responses = [completion[0][\"content\"] for completion in completions]\n",
    "\n",
    "    for content, valid_events, priorities, opt_score in zip(\n",
    "        responses, events, priority_events, optimal_score\n",
    "    ):\n",
    "        scheduled_events = get_events(content)\n",
    "\n",
    "        # Get valid scheduled events\n",
    "        existing_events = {\n",
    "            ev for ev in scheduled_events if [ev[0], ev[1], ev[2]] in valid_events\n",
    "        }\n",
    "\n",
    "        # penalize choosing nonexistent events or less than 2 events (not a valid schedule)\n",
    "        if len(existing_events) < len(scheduled_events) or len(existing_events) < 2:\n",
    "            scores.append(0.0)\n",
    "            continue\n",
    "\n",
    "        # Convert to minutes\n",
    "        existing_events_minutes = [\n",
    "            (ev[0], time_to_minutes(ev[1]), time_to_minutes(ev[2]))\n",
    "            for ev in existing_events\n",
    "        ]\n",
    "\n",
    "        # remove overlapping events and remove both events - to penalize overlaps\n",
    "        overlapping_events = set()\n",
    "        for j in range(len(existing_events_minutes)):\n",
    "            for k in range(j + 1, len(existing_events_minutes)):\n",
    "                if (\n",
    "                    existing_events_minutes[j][1] <= existing_events_minutes[k][2]\n",
    "                    and existing_events_minutes[j][2] >= existing_events_minutes[k][1]\n",
    "                ):\n",
    "                    overlapping_events.add(existing_events_minutes[j])\n",
    "                    overlapping_events.add(existing_events_minutes[k])\n",
    "\n",
    "        existing_events_minutes = [\n",
    "            ev for ev in existing_events_minutes if ev not in overlapping_events\n",
    "        ]\n",
    "\n",
    "        # Calculate score\n",
    "        score = sum(\n",
    "            2 * (ev[2] - ev[1]) if ev[0] in priorities else ev[2] - ev[1]\n",
    "            for ev in existing_events_minutes\n",
    "        )\n",
    "\n",
    "        scores.append((score / opt_score) * 70)\n",
    "\n",
    "    # Log samples\n",
    "    if any(score > 0 for score in scores) and random.random() < 0.10:\n",
    "        os.makedirs(\"completion_samples\", exist_ok=True)\n",
    "        log_file = os.path.join(\"completion_samples\", \"completion_samples.txt\")\n",
    "        with open(log_file, \"a\") as f:\n",
    "            f.write(\"\\n\\n==============\\n\")\n",
    "            f.write(f\"\\n{datetime.now().time()}\\n\")\n",
    "            f.write(f\"{prompts[0]}\\n\")\n",
    "            f.write(f\"{scores}\\n\")\n",
    "            f.write(f\"{completions}\")\n",
    "\n",
    "    return scores\n",
    "\n",
    "\n",
    "def sorted_events_reward(completions, **kwargs):\n",
    "    scores = []\n",
    "    responses = [completion[0][\"content\"] for completion in completions]\n",
    "\n",
    "    for response in responses:\n",
    "        scheduled_events = get_events(response)\n",
    "\n",
    "        # not a valid schedule: should be discarded\n",
    "        if len(scheduled_events) < 2:\n",
    "            scores.append(0.0)\n",
    "            continue\n",
    "\n",
    "        scheduled_events_minutes = [\n",
    "            (ev[0], time_to_minutes(ev[1]), time_to_minutes(ev[2]))\n",
    "            for ev in scheduled_events\n",
    "        ]\n",
    "\n",
    "        if all(\n",
    "            scheduled_events_minutes[i][1] < scheduled_events_minutes[i + 1][1]\n",
    "            for i in range(len(scheduled_events_minutes) - 1)\n",
    "        ):\n",
    "            scores.append(20.0)\n",
    "        else:\n",
    "            scores.append(0)\n",
    "\n",
    "    return scores"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5422b65",
   "metadata": {},
   "source": [
    "## Training configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fff0ee2",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_prompts = [\n",
    "    tokenizer.apply_chat_template(prompt, tokenize=True, add_generation_prompt=True)\n",
    "    for prompt in ds[\"prompt\"]\n",
    "]\n",
    "exact_max_prompt_length = max(\n",
    "    [len(tokenized_prompt) for tokenized_prompt in tokenized_prompts]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4e7aaa51-6574-40e0-a527-150cc7769648",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_prompt_length = 448\n",
    "\n",
    "new_model_id = \"anakin87/qwen-scheduler-7b-grpo\"\n",
    "\n",
    "\n",
    "from trl import GRPOConfig, GRPOTrainer\n",
    "\n",
    "training_args = GRPOConfig(\n",
    "    learning_rate=8e-6,\n",
    "    adam_beta1=0.9,\n",
    "    adam_beta2=0.99,\n",
    "    weight_decay=0.1,\n",
    "    warmup_ratio=0.01,\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    optim=\"paged_adamw_8bit\",\n",
    "    logging_steps=1,\n",
    "    per_device_train_batch_size=8,\n",
    "    gradient_accumulation_steps=1,\n",
    "    num_generations=8,  # Decrease if out of memory\n",
    "    max_prompt_length=max_prompt_length,\n",
    "    max_completion_length=max_seq_length - max_prompt_length,\n",
    "    max_grad_norm=0.1,\n",
    "    report_to=\"wandb\",\n",
    "    output_dir=\"outputs\",\n",
    "    overwrite_output_dir=True,\n",
    "    push_to_hub=True,\n",
    "    hub_model_id=new_model_id,\n",
    "    hub_strategy=\"every_save\",\n",
    "    save_strategy=\"steps\",\n",
    "    save_steps=50,\n",
    "    save_total_limit=1,\n",
    "    num_train_epochs=3,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a46731c-cf7a-4d8d-bce6-d7a0d9e00cf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "\n",
    "wandb.init(project=\"GRPO-reboost\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60ebd63b",
   "metadata": {},
   "source": [
    "## Training!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8708523b-5e99-408e-b8b4-b90ffd019b43",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = GRPOTrainer(\n",
    "    model=model,\n",
    "    processing_class=tokenizer,\n",
    "    reward_funcs=[\n",
    "        format_reward,\n",
    "        sorted_events_reward,\n",
    "        score_reward,\n",
    "    ],\n",
    "    args=training_args,\n",
    "    train_dataset=ds,\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0829576b",
   "metadata": {},
   "source": [
    "[Weights & Biases report](https://wandb.ai/stefanofiorucci/GRPO-reboost/reports/Qwen-Scheduler-GRPO--VmlldzoxMjI1MTA4MA?accessToken=p9whiiwc1ourpt1ae5hcs84un0ri117ty84m3c56kogvkm5drp5hnk9tanvlvrsn)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
